package kv

import (
	"context"
	"crypto/rand"
	"encoding/binary"
	mathRand "math/rand"
	"strconv"
	"testing"
	"time"

	"github.com/OffchainLabs/prysm/v7/beacon-chain/state"
	"github.com/OffchainLabs/prysm/v7/cmd/beacon-chain/flags"
	"github.com/OffchainLabs/prysm/v7/config/features"
	fieldparams "github.com/OffchainLabs/prysm/v7/config/fieldparams"
	"github.com/OffchainLabs/prysm/v7/config/params"
	"github.com/OffchainLabs/prysm/v7/consensus-types/blocks"
	"github.com/OffchainLabs/prysm/v7/consensus-types/interfaces"
	"github.com/OffchainLabs/prysm/v7/consensus-types/primitives"
	"github.com/OffchainLabs/prysm/v7/encoding/bytesutil"
	"github.com/OffchainLabs/prysm/v7/genesis"
	"github.com/OffchainLabs/prysm/v7/math"
	enginev1 "github.com/OffchainLabs/prysm/v7/proto/engine/v1"
	ethpb "github.com/OffchainLabs/prysm/v7/proto/prysm/v1alpha1"
	"github.com/OffchainLabs/prysm/v7/runtime/version"
	"github.com/OffchainLabs/prysm/v7/testing/assert"
	"github.com/OffchainLabs/prysm/v7/testing/require"
	"github.com/OffchainLabs/prysm/v7/testing/util"
	logTest "github.com/sirupsen/logrus/hooks/test"
	bolt "go.etcd.io/bbolt"
)

func TestStateNil(t *testing.T) {
	db := setupDB(t)
	_, err := db.StateOrError(t.Context(), [32]byte{})
	require.ErrorIs(t, err, ErrNotFoundState)
}

func TestState_CanSaveRetrieve(t *testing.T) {
	type testCase struct {
		name     string
		s        func() state.BeaconState
		rootSeed byte
	}

	cases := []testCase{
		{
			name: "phase0",
			s: func() state.BeaconState {
				st, err := util.NewBeaconState()
				require.NoError(t, err)
				require.NoError(t, st.SetSlot(100))
				return st
			},
			rootSeed: '0',
		},
		{
			name: "altair",
			s: func() state.BeaconState {
				st, err := util.NewBeaconStateAltair()
				require.NoError(t, err)
				require.NoError(t, st.SetSlot(100))
				return st
			},
			rootSeed: 'A',
		},
		{
			name: "bellatrix",
			s: func() state.BeaconState {
				st, err := util.NewBeaconStateBellatrix()
				require.NoError(t, err)
				require.NoError(t, st.SetSlot(100))
				p, err := blocks.WrappedExecutionPayloadHeader(&enginev1.ExecutionPayloadHeader{
					ParentHash:       make([]byte, 32),
					FeeRecipient:     make([]byte, 20),
					StateRoot:        make([]byte, 32),
					ReceiptsRoot:     make([]byte, 32),
					LogsBloom:        make([]byte, 256),
					PrevRandao:       make([]byte, 32),
					ExtraData:        []byte("foo"),
					BaseFeePerGas:    make([]byte, 32),
					BlockHash:        make([]byte, 32),
					TransactionsRoot: make([]byte, 32),
				})
				require.NoError(t, err)
				require.NoError(t, st.SetLatestExecutionPayloadHeader(p))
				return st
			},
			rootSeed: 'B',
		},
		{
			name: "capella",
			s: func() state.BeaconState {
				st, err := util.NewBeaconStateCapella()
				require.NoError(t, err)
				require.NoError(t, st.SetSlot(100))
				p, err := blocks.WrappedExecutionPayloadHeaderCapella(&enginev1.ExecutionPayloadHeaderCapella{
					ParentHash:       make([]byte, 32),
					FeeRecipient:     make([]byte, 20),
					StateRoot:        make([]byte, 32),
					ReceiptsRoot:     make([]byte, 32),
					LogsBloom:        make([]byte, 256),
					PrevRandao:       make([]byte, 32),
					ExtraData:        []byte("foo"),
					BaseFeePerGas:    make([]byte, 32),
					BlockHash:        make([]byte, 32),
					TransactionsRoot: make([]byte, 32),
					WithdrawalsRoot:  make([]byte, 32),
				})
				require.NoError(t, err)
				require.NoError(t, st.SetLatestExecutionPayloadHeader(p))
				return st
			},
			rootSeed: 'C',
		},
		{
			name: "deneb",
			s: func() state.BeaconState {
				st, err := util.NewBeaconStateDeneb()
				require.NoError(t, err)
				require.NoError(t, st.SetSlot(100))
				p, err := blocks.WrappedExecutionPayloadHeaderDeneb(&enginev1.ExecutionPayloadHeaderDeneb{
					ParentHash:       make([]byte, 32),
					FeeRecipient:     make([]byte, 20),
					StateRoot:        make([]byte, 32),
					ReceiptsRoot:     make([]byte, 32),
					LogsBloom:        make([]byte, 256),
					PrevRandao:       make([]byte, 32),
					ExtraData:        []byte("foo"),
					BaseFeePerGas:    make([]byte, 32),
					BlockHash:        make([]byte, 32),
					TransactionsRoot: make([]byte, 32),
					WithdrawalsRoot:  make([]byte, 32),
				})
				require.NoError(t, err)
				require.NoError(t, st.SetLatestExecutionPayloadHeader(p))
				return st
			},
			rootSeed: 'D',
		},
		{
			name: "electra",
			s: func() state.BeaconState {
				st, err := util.NewBeaconStateElectra()
				require.NoError(t, err)
				require.NoError(t, st.SetSlot(100))
				p, err := blocks.WrappedExecutionPayloadHeaderDeneb(&enginev1.ExecutionPayloadHeaderDeneb{
					ParentHash:       make([]byte, 32),
					FeeRecipient:     make([]byte, 20),
					StateRoot:        make([]byte, 32),
					ReceiptsRoot:     make([]byte, 32),
					LogsBloom:        make([]byte, 256),
					PrevRandao:       make([]byte, 32),
					ExtraData:        []byte("foo"),
					BaseFeePerGas:    make([]byte, 32),
					BlockHash:        make([]byte, 32),
					TransactionsRoot: make([]byte, 32),
					WithdrawalsRoot:  make([]byte, 32),
				})
				require.NoError(t, err)
				require.NoError(t, st.SetLatestExecutionPayloadHeader(p))
				return st
			},
			rootSeed: 'E',
		},
	}

	db := setupDB(t)

	for _, enableFlag := range []bool{true, false} {
		reset := features.InitWithReset(&features.Flags{EnableHistoricalSpaceRepresentation: enableFlag})

		for _, tc := range cases {
			t.Run(tc.name+" - EnableHistoricalSpaceRepresentation is "+strconv.FormatBool(enableFlag), func(t *testing.T) {
				rootNonce := byte('0')
				if enableFlag {
					rootNonce = '1'
				}
				root := bytesutil.ToBytes32([]byte{tc.rootSeed, rootNonce})
				require.Equal(t, false, db.HasState(t.Context(), root))
				st := tc.s()

				require.NoError(t, db.SaveState(t.Context(), st, root))
				assert.Equal(t, true, db.HasState(t.Context(), root))

				savedSt, err := db.State(t.Context(), root)
				require.NoError(t, err)

				assert.DeepSSZEqual(t, st.ToProtoUnsafe(), savedSt.ToProtoUnsafe())
			})
		}

		reset()
	}
}

func TestState_CanSaveRetrieveValidatorEntries(t *testing.T) {
	db := setupDB(t)

	// enable historical state representation flag to test this
	resetCfg := features.InitWithReset(&features.Flags{
		EnableHistoricalSpaceRepresentation: true,
	})
	defer resetCfg()

	r := [32]byte{'A'}

	require.Equal(t, false, db.HasState(t.Context(), r))

	stateValidators := validators(10)
	st, err := util.NewBeaconState()
	require.NoError(t, err)
	require.NoError(t, st.SetSlot(100))
	require.NoError(t, st.SetValidators(stateValidators))

	ctx := t.Context()
	require.NoError(t, db.SaveState(ctx, st, r))
	assert.Equal(t, true, db.HasState(t.Context(), r))

	savedS, err := db.State(t.Context(), r)
	require.NoError(t, err)

	require.DeepSSZEqual(t, st.ToProtoUnsafe(), savedS.ToProtoUnsafe(), "saved state with validators and retrieved state are not matching")

	// check if the index of the second state is still present.
	err = db.db.Update(func(tx *bolt.Tx) error {
		idxBkt := tx.Bucket(blockRootValidatorHashesBucket)
		data := idxBkt.Get(r[:])
		require.NotEqual(t, 0, len(data))
		return nil
	})
	require.NoError(t, err)

	// check if all the validator entries are still intact in the validator entry bucket.
	err = db.db.Update(func(tx *bolt.Tx) error {
		valBkt := tx.Bucket(stateValidatorsBucket)
		// if any of the original validator entry is not present, then fail the test.
		for _, val := range stateValidators {
			hash, hashErr := val.HashTreeRoot()
			assert.NoError(t, hashErr)
			data := valBkt.Get(hash[:])
			require.NotNil(t, data)
			require.NotEqual(t, 0, len(data))
		}
		return nil
	})
	require.NoError(t, err)
}

func TestStateAltair_CanSaveRetrieveValidatorEntries(t *testing.T) {
	db := setupDB(t)

	// enable historical state representation flag to test this
	resetCfg := features.InitWithReset(&features.Flags{
		EnableHistoricalSpaceRepresentation: true,
	})
	defer resetCfg()

	r := [32]byte{'A'}

	require.Equal(t, false, db.HasState(t.Context(), r))

	stateValidators := validators(10)
	st, _ := util.DeterministicGenesisStateAltair(t, 20)
	require.NoError(t, st.SetSlot(100))
	require.NoError(t, st.SetValidators(stateValidators))

	ctx := t.Context()
	require.NoError(t, db.SaveState(ctx, st, r))
	assert.Equal(t, true, db.HasState(t.Context(), r))

	savedS, err := db.State(t.Context(), r)
	require.NoError(t, err)

	require.DeepSSZEqual(t, st.ToProtoUnsafe(), savedS.ToProtoUnsafe(), "saved state with validators and retrieved state are not matching")

	// check if the index of the second state is still present.
	err = db.db.Update(func(tx *bolt.Tx) error {
		idxBkt := tx.Bucket(blockRootValidatorHashesBucket)
		data := idxBkt.Get(r[:])
		require.NotEqual(t, 0, len(data))
		return nil
	})
	require.NoError(t, err)

	// check if all the validator entries are still intact in the validator entry bucket.
	err = db.db.Update(func(tx *bolt.Tx) error {
		valBkt := tx.Bucket(stateValidatorsBucket)
		// if any of the original validator entry is not present, then fail the test.
		for _, val := range stateValidators {
			hash, hashErr := val.HashTreeRoot()
			assert.NoError(t, hashErr)
			data := valBkt.Get(hash[:])
			require.NotNil(t, data)
			require.NotEqual(t, 0, len(data))
		}
		return nil
	})
	require.NoError(t, err)
}

func TestState_CanSaveRetrieveValidatorEntriesFromCache(t *testing.T) {
	db := setupDB(t)

	// enable historical state representation flag to test this
	resetCfg := features.InitWithReset(&features.Flags{
		EnableHistoricalSpaceRepresentation: true,
	})
	defer resetCfg()

	r := [32]byte{'A'}

	require.Equal(t, false, db.HasState(t.Context(), r))

	stateValidators := validators(10)
	st, err := util.NewBeaconState()
	require.NoError(t, err)
	require.NoError(t, st.SetSlot(100))
	require.NoError(t, st.SetValidators(stateValidators))

	ctx := t.Context()
	require.NoError(t, db.SaveState(ctx, st, r))
	assert.Equal(t, true, db.HasState(t.Context(), r))

	// check if the state is in cache
	for i := range stateValidators {
		hash, hashErr := stateValidators[i].HashTreeRoot()
		assert.NoError(t, hashErr)

		data, ok := db.validatorEntryCache.Get(hash[:])
		assert.Equal(t, true, ok)
		require.NotNil(t, data)

		require.DeepSSZEqual(t, stateValidators[i], data, "validator entry is not matching")
	}

	// check if all the validator entries are still intact in the validator entry bucket.
	err = db.db.Update(func(tx *bolt.Tx) error {
		valBkt := tx.Bucket(stateValidatorsBucket)
		// if any of the original validator entry is not present, then fail the test.
		for _, val := range stateValidators {
			hash, hashErr := val.HashTreeRoot()
			assert.NoError(t, hashErr)
			data := valBkt.Get(hash[:])
			require.NotNil(t, data)
			require.NotEqual(t, 0, len(data))
		}
		return nil
	})
	require.NoError(t, err)

}

func TestState_CanSaveRetrieveValidatorEntriesWithoutCache(t *testing.T) {
	db := setupDB(t)

	// enable historical state representation flag to test this
	resetCfg := features.InitWithReset(&features.Flags{
		EnableHistoricalSpaceRepresentation: true,
	})
	defer resetCfg()

	r := [32]byte{'A'}

	require.Equal(t, false, db.HasState(t.Context(), r))

	stateValidators := validators(10)
	st, err := util.NewBeaconState()
	require.NoError(t, err)
	require.NoError(t, st.SetSlot(100))
	require.NoError(t, st.SetValidators(stateValidators))

	ctx := t.Context()
	require.NoError(t, db.SaveState(ctx, st, r))
	assert.Equal(t, true, db.HasState(t.Context(), r))
	db.validatorEntryCache.Clear()

	savedS, err := db.State(t.Context(), r)
	require.NoError(t, err)

	require.DeepSSZEqual(t, st.ToProtoUnsafe(), savedS.ToProtoUnsafe(), "saved state with validators and retrieved state are not matching")

	// check if the index of the second state is still present.
	err = db.db.Update(func(tx *bolt.Tx) error {
		idxBkt := tx.Bucket(blockRootValidatorHashesBucket)
		data := idxBkt.Get(r[:])
		require.NotEqual(t, 0, len(data))
		return nil
	})
	require.NoError(t, err)

	// check if all the validator entries are still intact in the validator entry bucket.
	err = db.db.Update(func(tx *bolt.Tx) error {
		valBkt := tx.Bucket(stateValidatorsBucket)
		// if any of the original validator entry is not present, then fail the test.
		for _, val := range stateValidators {
			hash, hashErr := val.HashTreeRoot()
			assert.NoError(t, hashErr)
			data := valBkt.Get(hash[:])
			require.NotNil(t, data)
			require.NotEqual(t, 0, len(data))
		}
		return nil
	})
	require.NoError(t, err)

}

func TestState_DeleteState(t *testing.T) {
	db := setupDB(t)

	// enable historical state representation flag to test this
	resetCfg := features.InitWithReset(&features.Flags{
		EnableHistoricalSpaceRepresentation: true,
	})
	defer resetCfg()

	r1 := [32]byte{'A'}
	r2 := [32]byte{'B'}

	require.Equal(t, false, db.HasState(t.Context(), r1))
	require.Equal(t, false, db.HasState(t.Context(), r2))

	// create two states with the same set of validators.
	stateValidators := validators(10)
	st1, err := util.NewBeaconState()
	require.NoError(t, err)
	require.NoError(t, st1.SetSlot(100))
	require.NoError(t, st1.SetValidators(stateValidators))

	st2, err := util.NewBeaconState()
	require.NoError(t, err)
	require.NoError(t, st2.SetSlot(101))
	require.NoError(t, st2.SetValidators(stateValidators))

	// save both the states.
	ctx := t.Context()
	require.NoError(t, db.SaveState(ctx, st1, r1))
	require.NoError(t, db.SaveState(ctx, st2, r2))

	// delete the first state.
	var deleteBlockRoots [][32]byte
	deleteBlockRoots = append(deleteBlockRoots, r1)
	require.NoError(t, db.DeleteStates(ctx, deleteBlockRoots))

	// check if the validator entries of this state is removed from cache.
	for _, val := range stateValidators {
		hash, hashErr := val.HashTreeRoot()
		assert.NoError(t, hashErr)
		v, found := db.validatorEntryCache.Get(hash[:])
		require.Equal(t, false, found)
		require.IsNil(t, v)
	}

	// check if the index of the first state is deleted.
	err = db.db.Update(func(tx *bolt.Tx) error {
		idxBkt := tx.Bucket(blockRootValidatorHashesBucket)
		data := idxBkt.Get(r1[:])
		require.Equal(t, 0, len(data))
		return nil
	})
	require.NoError(t, err)

	// check if the index of the second state is still present.
	err = db.db.Update(func(tx *bolt.Tx) error {
		idxBkt := tx.Bucket(blockRootValidatorHashesBucket)
		data := idxBkt.Get(r2[:])
		require.NotEqual(t, 0, len(data))
		return nil
	})
	require.NoError(t, err)

	// check if all the validator entries are still intact in the validator entry bucket.
	err = db.db.Update(func(tx *bolt.Tx) error {
		valBkt := tx.Bucket(stateValidatorsBucket)
		// if any of the original validator entry is not present, then fail the test.
		for _, val := range stateValidators {
			hash, hashErr := val.HashTreeRoot()
			assert.NoError(t, hashErr)
			data := valBkt.Get(hash[:])
			require.NotNil(t, data)
			require.NotEqual(t, 0, len(data))
		}
		return nil
	})
	require.NoError(t, err)
}

func TestGenesisState_CanSaveRetrieve(t *testing.T) {
	db := setupDB(t)

	headRoot := [32]byte{'B'}

	st, err := util.NewBeaconState()
	require.NoError(t, err)
	require.NoError(t, st.SetSlot(1))
	require.NoError(t, db.SaveGenesisBlockRoot(t.Context(), headRoot))
	genesis.StoreStateDuringTest(t, st)

	savedGenesisS, err := db.GenesisState(t.Context())
	require.NoError(t, err)
	assert.DeepSSZEqual(t, st.ToProtoUnsafe(), savedGenesisS.ToProtoUnsafe(), "Did not retrieve saved state")
	require.NoError(t, db.SaveGenesisBlockRoot(t.Context(), [32]byte{'C'}))
}

func TestStore_StatesBatchDelete(t *testing.T) {
	db := setupDB(t)
	ctx := t.Context()
	numBlocks := 100
	totalBlocks := make([]interfaces.ReadOnlySignedBeaconBlock, numBlocks)
	blockRoots := make([][32]byte, 0)
	evenBlockRoots := make([][32]byte, 0)
	for i := range totalBlocks {
		b := util.NewBeaconBlock()
		b.Block.Slot = primitives.Slot(i)
		var err error
		totalBlocks[i], err = blocks.NewSignedBeaconBlock(b)
		require.NoError(t, err)
		r, err := totalBlocks[i].Block().HashTreeRoot()
		require.NoError(t, err)
		st, err := util.NewBeaconState()
		require.NoError(t, err)
		require.NoError(t, st.SetSlot(primitives.Slot(i)))
		require.NoError(t, db.SaveState(t.Context(), st, r))
		blockRoots = append(blockRoots, r)
		if i%2 == 0 {
			evenBlockRoots = append(evenBlockRoots, r)
		}
	}
	require.NoError(t, db.SaveBlocks(ctx, totalBlocks))
	// We delete all even indexed states.
	require.NoError(t, db.DeleteStates(ctx, evenBlockRoots))
	// When we retrieve the data, only the odd indexed state should remain.
	for _, r := range blockRoots {
		s, err := db.State(t.Context(), r)
		require.NoError(t, err)
		if s == nil {
			continue
		}
		assert.Equal(t, primitives.Slot(1), s.Slot()%2, "State with slot %d should have been deleted", s.Slot())
	}
}

func TestStore_DeleteGenesisState(t *testing.T) {
	db := setupDB(t)
	ctx := t.Context()

	genesisBlockRoot := [32]byte{'A'}
	require.NoError(t, db.SaveGenesisBlockRoot(ctx, genesisBlockRoot))
	st, err := util.NewBeaconState()
	require.NoError(t, err)
	require.NoError(t, st.SetSlot(100))
	require.NoError(t, db.SaveState(ctx, st, genesisBlockRoot))
	wantedErr := "cannot delete finalized block or state"
	assert.ErrorContains(t, wantedErr, db.DeleteState(ctx, genesisBlockRoot))
}

func TestStore_DeleteFinalizedState(t *testing.T) {
	db := setupDB(t)
	ctx := t.Context()

	genesis := bytesutil.ToBytes32([]byte{'G', 'E', 'N', 'E', 'S', 'I', 'S'})
	require.NoError(t, db.SaveGenesisBlockRoot(ctx, genesis))

	blk := util.NewBeaconBlock()
	blk.Block.ParentRoot = genesis[:]
	blk.Block.Slot = 100

	wsb, err := blocks.NewSignedBeaconBlock(blk)
	require.NoError(t, err)
	require.NoError(t, db.SaveBlock(ctx, wsb))

	finalizedBlockRoot, err := blk.Block.HashTreeRoot()
	require.NoError(t, err)

	finalizedState, err := util.NewBeaconState()
	require.NoError(t, err)
	require.NoError(t, finalizedState.SetSlot(100))
	require.NoError(t, db.SaveState(ctx, finalizedState, finalizedBlockRoot))
	finalizedCheckpoint := &ethpb.Checkpoint{Root: finalizedBlockRoot[:]}
	require.NoError(t, db.SaveFinalizedCheckpoint(ctx, finalizedCheckpoint))
	wantedErr := "cannot delete finalized block or state"
	assert.ErrorContains(t, wantedErr, db.DeleteState(ctx, finalizedBlockRoot))
}

func TestStore_DeleteHeadState(t *testing.T) {
	db := setupDB(t)
	ctx := t.Context()

	genesis := bytesutil.ToBytes32([]byte{'G', 'E', 'N', 'E', 'S', 'I', 'S'})
	require.NoError(t, db.SaveGenesisBlockRoot(ctx, genesis))

	blk := util.NewBeaconBlock()
	blk.Block.ParentRoot = genesis[:]
	blk.Block.Slot = 100
	wsb, err := blocks.NewSignedBeaconBlock(blk)
	require.NoError(t, err)
	require.NoError(t, db.SaveBlock(ctx, wsb))

	headBlockRoot, err := blk.Block.HashTreeRoot()
	require.NoError(t, err)
	st, err := util.NewBeaconState()
	require.NoError(t, err)
	require.NoError(t, st.SetSlot(100))
	require.NoError(t, db.SaveState(ctx, st, headBlockRoot))
	require.NoError(t, db.SaveHeadBlockRoot(ctx, headBlockRoot))
	require.NoError(t, db.DeleteState(ctx, headBlockRoot)) // Ok to delete head state if it's optimistic.
}

func TestStore_SaveDeleteState_CanGetHighestBelow(t *testing.T) {
	db := setupDB(t)

	b := util.NewBeaconBlock()
	b.Block.Slot = 1
	r, err := b.Block.HashTreeRoot()
	require.NoError(t, err)
	wsb, err := blocks.NewSignedBeaconBlock(b)
	require.NoError(t, err)
	require.NoError(t, db.SaveBlock(t.Context(), wsb))
	st, err := util.NewBeaconState()
	require.NoError(t, err)
	require.NoError(t, st.SetSlot(1))
	s0 := st.ToProtoUnsafe()
	require.NoError(t, db.SaveState(t.Context(), st, r))

	b.Block.Slot = 100
	r1, err := b.Block.HashTreeRoot()
	require.NoError(t, err)
	wsb, err = blocks.NewSignedBeaconBlock(b)
	require.NoError(t, err)
	require.NoError(t, db.SaveBlock(t.Context(), wsb))
	st, err = util.NewBeaconState()
	require.NoError(t, err)
	require.NoError(t, st.SetSlot(100))
	s1 := st.ToProtoUnsafe()
	require.NoError(t, db.SaveState(t.Context(), st, r1))

	b.Block.Slot = 1000
	r2, err := b.Block.HashTreeRoot()
	require.NoError(t, err)
	wsb, err = blocks.NewSignedBeaconBlock(b)
	require.NoError(t, err)
	require.NoError(t, db.SaveBlock(t.Context(), wsb))
	st, err = util.NewBeaconState()
	require.NoError(t, err)
	require.NoError(t, st.SetSlot(1000))
	s2 := st.ToProtoUnsafe()

	require.NoError(t, db.SaveState(t.Context(), st, r2))

	highest, err := db.HighestSlotStatesBelow(t.Context(), 2)
	require.NoError(t, err)
	assert.DeepSSZEqual(t, highest[0].ToProtoUnsafe(), s0)

	highest, err = db.HighestSlotStatesBelow(t.Context(), 101)
	require.NoError(t, err)
	assert.DeepSSZEqual(t, highest[0].ToProtoUnsafe(), s1)

	highest, err = db.HighestSlotStatesBelow(t.Context(), 1001)
	require.NoError(t, err)
	assert.DeepSSZEqual(t, highest[0].ToProtoUnsafe(), s2)
}

func TestStore_GenesisState_CanGetHighestBelow(t *testing.T) {
	db := setupDB(t)

	genesisState, err := util.NewBeaconState()
	require.NoError(t, err)
	genesisRoot := [32]byte{'a'}
	require.NoError(t, db.SaveGenesisBlockRoot(t.Context(), genesisRoot))
	genesis.StoreStateDuringTest(t, genesisState)

	b := util.NewBeaconBlock()
	b.Block.Slot = 1
	r, err := b.Block.HashTreeRoot()
	require.NoError(t, err)
	wsb, err := blocks.NewSignedBeaconBlock(b)
	require.NoError(t, err)
	require.NoError(t, db.SaveBlock(t.Context(), wsb))

	st, err := util.NewBeaconState()
	require.NoError(t, err)
	require.NoError(t, st.SetSlot(1))
	require.NoError(t, db.SaveState(t.Context(), st, r))

	highest, err := db.HighestSlotStatesBelow(t.Context(), 2)
	require.NoError(t, err)
	assert.DeepSSZEqual(t, highest[0].ToProtoUnsafe(), st.ToProtoUnsafe())

	highest, err = db.HighestSlotStatesBelow(t.Context(), 1)
	require.NoError(t, err)
	assert.DeepSSZEqual(t, highest[0].ToProtoUnsafe(), genesisState.ToProtoUnsafe())
	highest, err = db.HighestSlotStatesBelow(t.Context(), 0)
	require.NoError(t, err)
	assert.DeepSSZEqual(t, highest[0].ToProtoUnsafe(), genesisState.ToProtoUnsafe())
}

func TestStore_CleanUpDirtyStates_AboveThreshold(t *testing.T) {
	db := setupDB(t)

	genesisState, err := util.NewBeaconState()
	require.NoError(t, err)
	genesisRoot := [32]byte{'a'}
	require.NoError(t, db.SaveGenesisBlockRoot(t.Context(), genesisRoot))
	require.NoError(t, db.SaveState(t.Context(), genesisState, genesisRoot))
	require.NoError(t, db.SaveOriginCheckpointBlockRoot(t.Context(), [32]byte{'a'}))

	bRoots := make([][32]byte, 0)
	slotsPerArchivedPoint := primitives.Slot(128)
	prevRoot := genesisRoot
	for i := primitives.Slot(1); i <= slotsPerArchivedPoint; i++ {
		b := util.NewBeaconBlock()
		b.Block.Slot = i
		b.Block.ParentRoot = prevRoot[:]
		r, err := b.Block.HashTreeRoot()
		require.NoError(t, err)
		wsb, err := blocks.NewSignedBeaconBlock(b)
		require.NoError(t, err)
		require.NoError(t, db.SaveBlock(t.Context(), wsb))
		bRoots = append(bRoots, r)
		prevRoot = r

		st, err := util.NewBeaconState()
		require.NoError(t, err)
		require.NoError(t, st.SetSlot(i))
		require.NoError(t, db.SaveState(t.Context(), st, r))
	}

	require.NoError(t, db.SaveFinalizedCheckpoint(t.Context(), &ethpb.Checkpoint{
		Root:  bRoots[len(bRoots)-1][:],
		Epoch: primitives.Epoch(slotsPerArchivedPoint / params.BeaconConfig().SlotsPerEpoch),
	}))
	require.NoError(t, db.CleanUpDirtyStates(t.Context(), slotsPerArchivedPoint))

	for i, root := range bRoots {
		if primitives.Slot(i) >= slotsPerArchivedPoint.SubSlot(slotsPerArchivedPoint.Div(3)) {
			require.Equal(t, true, db.HasState(t.Context(), root))
		} else {
			require.Equal(t, false, db.HasState(t.Context(), root))
		}
	}
}

func TestStore_CleanUpDirtyStates_Finalized(t *testing.T) {
	db := setupDB(t)

	genesisState, err := util.NewBeaconState()
	require.NoError(t, err)
	genesisRoot := [32]byte{'a'}
	require.NoError(t, db.SaveGenesisBlockRoot(t.Context(), genesisRoot))
	require.NoError(t, db.SaveState(t.Context(), genesisState, genesisRoot))
	require.NoError(t, db.SaveOriginCheckpointBlockRoot(t.Context(), [32]byte{'a'}))

	for i := primitives.Slot(1); i <= params.BeaconConfig().SlotsPerEpoch; i++ {
		b := util.NewBeaconBlock()
		b.Block.Slot = i
		r, err := b.Block.HashTreeRoot()
		require.NoError(t, err)
		wsb, err := blocks.NewSignedBeaconBlock(b)
		require.NoError(t, err)
		require.NoError(t, db.SaveBlock(t.Context(), wsb))

		st, err := util.NewBeaconState()
		require.NoError(t, err)
		require.NoError(t, st.SetSlot(i))
		require.NoError(t, db.SaveState(t.Context(), st, r))
	}

	require.NoError(t, db.SaveFinalizedCheckpoint(t.Context(), &ethpb.Checkpoint{Root: genesisRoot[:]}))
	require.NoError(t, db.CleanUpDirtyStates(t.Context(), params.BeaconConfig().SlotsPerEpoch))
	require.Equal(t, true, db.HasState(t.Context(), genesisRoot))
}

func TestStore_CleanUpDirtyStates_OriginRoot(t *testing.T) {
	db := setupDB(t)

	genesisState, err := util.NewBeaconState()
	require.NoError(t, err)
	r := [32]byte{'a'}
	require.NoError(t, db.SaveGenesisBlockRoot(t.Context(), r))
	require.NoError(t, db.SaveState(t.Context(), genesisState, r))

	for i := primitives.Slot(1); i <= params.BeaconConfig().SlotsPerEpoch; i++ {
		b := util.NewBeaconBlock()
		b.Block.Slot = i
		r, err := b.Block.HashTreeRoot()
		require.NoError(t, err)
		wsb, err := blocks.NewSignedBeaconBlock(b)
		require.NoError(t, err)
		require.NoError(t, db.SaveBlock(t.Context(), wsb))

		st, err := util.NewBeaconState()
		require.NoError(t, err)
		require.NoError(t, st.SetSlot(i))
		require.NoError(t, db.SaveState(t.Context(), st, r))
	}

	require.NoError(t, db.SaveOriginCheckpointBlockRoot(t.Context(), r))
	require.NoError(t, db.CleanUpDirtyStates(t.Context(), params.BeaconConfig().SlotsPerEpoch))
	require.Equal(t, true, db.HasState(t.Context(), r))
}

func TestStore_CleanUpDirtyStates_DontDeleteNonFinalized(t *testing.T) {
	db := setupDB(t)

	genesisState, err := util.NewBeaconState()
	require.NoError(t, err)
	genesisRoot := [32]byte{'a'}
	require.NoError(t, db.SaveGenesisBlockRoot(t.Context(), genesisRoot))
	require.NoError(t, db.SaveState(t.Context(), genesisState, genesisRoot))
	require.NoError(t, db.SaveOriginCheckpointBlockRoot(t.Context(), [32]byte{'a'}))

	var unfinalizedRoots [][32]byte
	for i := primitives.Slot(1); i <= params.BeaconConfig().SlotsPerEpoch; i++ {
		b := util.NewBeaconBlock()
		b.Block.Slot = i
		r, err := b.Block.HashTreeRoot()
		require.NoError(t, err)
		wsb, err := blocks.NewSignedBeaconBlock(b)
		require.NoError(t, err)
		require.NoError(t, db.SaveBlock(t.Context(), wsb))
		unfinalizedRoots = append(unfinalizedRoots, r)

		st, err := util.NewBeaconState()
		require.NoError(t, err)
		require.NoError(t, st.SetSlot(i))
		require.NoError(t, db.SaveState(t.Context(), st, r))
	}

	require.NoError(t, db.SaveFinalizedCheckpoint(t.Context(), &ethpb.Checkpoint{Root: genesisRoot[:]}))
	require.NoError(t, db.CleanUpDirtyStates(t.Context(), params.BeaconConfig().SlotsPerEpoch))

	for _, rt := range unfinalizedRoots {
		require.Equal(t, true, db.HasState(t.Context(), rt))
	}
}

func TestAltairState_CanSaveRetrieve(t *testing.T) {
	db := setupDB(t)

	r := [32]byte{'A'}

	require.Equal(t, false, db.HasState(t.Context(), r))

	st, _ := util.DeterministicGenesisStateAltair(t, 1)
	require.NoError(t, st.SetSlot(100))

	require.NoError(t, db.SaveState(t.Context(), st, r))
	require.Equal(t, true, db.HasState(t.Context(), r))

	savedS, err := db.State(t.Context(), r)
	require.NoError(t, err)

	require.DeepSSZEqual(t, st.ToProtoUnsafe(), savedS.ToProtoUnsafe())

	savedS, err = db.State(t.Context(), [32]byte{'B'})
	require.NoError(t, err)
	require.Equal(t, state.ReadOnlyBeaconState(nil), savedS, "Unsaved state should've been nil")
}

func TestAltairState_CanDelete(t *testing.T) {
	db := setupDB(t)

	r := [32]byte{'A'}

	require.Equal(t, false, db.HasState(t.Context(), r))

	st, _ := util.DeterministicGenesisStateAltair(t, 1)
	require.NoError(t, st.SetSlot(100))

	require.NoError(t, db.SaveState(t.Context(), st, r))
	require.Equal(t, true, db.HasState(t.Context(), r))

	require.NoError(t, db.DeleteState(t.Context(), r))
	savedS, err := db.State(t.Context(), r)
	require.NoError(t, err)
	require.Equal(t, state.ReadOnlyBeaconState(nil), savedS, "Unsaved state should've been nil")
}

func validators(limit int) []*ethpb.Validator {
	var vals []*ethpb.Validator
	for i := range limit {
		pubKey := make([]byte, params.BeaconConfig().BLSPubkeyLength)
		binary.LittleEndian.PutUint64(pubKey, mathRand.Uint64())
		val := &ethpb.Validator{
			PublicKey:                  pubKey,
			WithdrawalCredentials:      bytesutil.ToBytes(mathRand.Uint64(), 32),
			EffectiveBalance:           mathRand.Uint64(),
			Slashed:                    i%2 != 0,
			ActivationEligibilityEpoch: primitives.Epoch(mathRand.Uint64()),
			ActivationEpoch:            primitives.Epoch(mathRand.Uint64()),
			ExitEpoch:                  primitives.Epoch(mathRand.Uint64()),
			WithdrawableEpoch:          primitives.Epoch(mathRand.Uint64()),
		}
		vals = append(vals, val)
	}
	return vals
}

func checkStateSaveTime(b *testing.B, saveCount int) {

	db := setupDB(b)
	initialSetOfValidators := validators(100000)

	// construct some states and save to randomize benchmark.
	for range saveCount {
		key := make([]byte, 32)
		_, err := rand.Read(key)
		require.NoError(b, err)
		st, err := util.NewBeaconState()
		require.NoError(b, err)

		// Add some more new validator to the base validator.
		validatosToAddInTest := validators(10000)
		allValidators := append(initialSetOfValidators, validatosToAddInTest...)

		// shuffle validators.
		mathRand.New(mathRand.NewSource(time.Now().UnixNano()))
		mathRand.Shuffle(len(allValidators), func(i, j int) { allValidators[i], allValidators[j] = allValidators[j], allValidators[i] })

		require.NoError(b, st.SetValidators(allValidators))
		require.NoError(b, db.SaveState(b.Context(), st, bytesutil.ToBytes32(key)))
	}

	// create a state to save in benchmark
	r := [32]byte{'A'}
	st, err := util.NewBeaconState()
	require.NoError(b, err)
	require.NoError(b, st.SetValidators(initialSetOfValidators))

	b.ReportAllocs()

	for b.Loop() {
		require.NoError(b, db.SaveState(b.Context(), st, r))
	}
}

func checkStateReadTime(b *testing.B, saveCount int) {

	db := setupDB(b)
	initialSetOfValidators := validators(100000)

	// Save a state to read in benchmark
	r := [32]byte{'A'}
	st, err := util.NewBeaconState()
	require.NoError(b, err)
	require.NoError(b, st.SetValidators(initialSetOfValidators))
	require.NoError(b, db.SaveState(b.Context(), st, r))

	// construct some states and save to randomize benchmark.
	for range saveCount {
		key := make([]byte, 32)
		_, err := rand.Read(key)
		require.NoError(b, err)
		st, err = util.NewBeaconState()
		require.NoError(b, err)

		// Add some more new validator to the base validator.
		validatosToAddInTest := validators(10000)
		allValidators := append(initialSetOfValidators, validatosToAddInTest...)

		// shuffle validators.
		mathRand.New(mathRand.NewSource(time.Now().UnixNano()))
		mathRand.Shuffle(len(allValidators), func(i, j int) { allValidators[i], allValidators[j] = allValidators[j], allValidators[i] })

		require.NoError(b, st.SetValidators(allValidators))
		require.NoError(b, db.SaveState(b.Context(), st, bytesutil.ToBytes32(key)))
	}

	b.ReportAllocs()

	for b.Loop() {
		_, err := db.State(b.Context(), r)
		require.NoError(b, err)
	}
}

func TestStateBellatrix_CanSaveRetrieveValidatorEntries(t *testing.T) {
	db := setupDB(t)

	// enable historical state representation flag to test this
	resetCfg := features.InitWithReset(&features.Flags{
		EnableHistoricalSpaceRepresentation: true,
	})
	defer resetCfg()

	r := [32]byte{'A'}

	require.Equal(t, false, db.HasState(t.Context(), r))

	stateValidators := validators(10)
	st, _ := util.DeterministicGenesisStateBellatrix(t, 20)
	require.NoError(t, st.SetSlot(100))
	require.NoError(t, st.SetValidators(stateValidators))

	ctx := t.Context()
	require.NoError(t, db.SaveState(ctx, st, r))
	assert.Equal(t, true, db.HasState(t.Context(), r))

	savedS, err := db.State(t.Context(), r)
	require.NoError(t, err)

	require.DeepSSZEqual(t, st.ToProtoUnsafe(), savedS.ToProtoUnsafe(), "saved state with validators and retrieved state are not matching")

	// check if the index of the second state is still present.
	err = db.db.Update(func(tx *bolt.Tx) error {
		idxBkt := tx.Bucket(blockRootValidatorHashesBucket)
		data := idxBkt.Get(r[:])
		require.NotEqual(t, 0, len(data))
		return nil
	})
	require.NoError(t, err)

	// check if all the validator entries are still intact in the validator entry bucket.
	err = db.db.Update(func(tx *bolt.Tx) error {
		valBkt := tx.Bucket(stateValidatorsBucket)
		// if any of the original validator entry is not present, then fail the test.
		for _, val := range stateValidators {
			hash, hashErr := val.HashTreeRoot()
			assert.NoError(t, hashErr)
			data := valBkt.Get(hash[:])
			require.NotNil(t, data)
			require.NotEqual(t, 0, len(data))
		}
		return nil
	})
	require.NoError(t, err)
}

func TestBellatrixState_CanSaveRetrieve(t *testing.T) {
	db := setupDB(t)

	r := [32]byte{'A'}

	require.Equal(t, false, db.HasState(t.Context(), r))

	st, _ := util.DeterministicGenesisStateBellatrix(t, 1)
	require.NoError(t, st.SetSlot(100))

	require.NoError(t, db.SaveState(t.Context(), st, r))
	require.Equal(t, true, db.HasState(t.Context(), r))

	savedS, err := db.State(t.Context(), r)
	require.NoError(t, err)

	require.DeepSSZEqual(t, st.ToProtoUnsafe(), savedS.ToProtoUnsafe())

	savedS, err = db.State(t.Context(), [32]byte{'B'})
	require.NoError(t, err)
	require.Equal(t, state.ReadOnlyBeaconState(nil), savedS, "Unsaved state should've been nil")
}

func TestBellatrixState_CanDelete(t *testing.T) {
	db := setupDB(t)

	r := [32]byte{'A'}

	require.Equal(t, false, db.HasState(t.Context(), r))

	st, _ := util.DeterministicGenesisStateBellatrix(t, 1)
	require.NoError(t, st.SetSlot(100))

	require.NoError(t, db.SaveState(t.Context(), st, r))
	require.Equal(t, true, db.HasState(t.Context(), r))

	require.NoError(t, db.DeleteState(t.Context(), r))
	savedS, err := db.State(t.Context(), r)
	require.NoError(t, err)
	require.Equal(t, state.ReadOnlyBeaconState(nil), savedS, "Unsaved state should've been nil")
}

func TestBellatrixState_CanDeleteWithBlock(t *testing.T) {
	db := setupDB(t)

	b := util.NewBeaconBlockBellatrix()
	b.Block.Slot = 100
	r, err := b.Block.HashTreeRoot()
	require.NoError(t, err)
	wsb, err := blocks.NewSignedBeaconBlock(b)
	require.NoError(t, err)
	require.NoError(t, db.SaveBlock(t.Context(), wsb))

	require.Equal(t, false, db.HasState(t.Context(), r))

	st, _ := util.DeterministicGenesisStateBellatrix(t, 1)
	require.NoError(t, st.SetSlot(100))

	require.NoError(t, db.SaveState(t.Context(), st, r))
	require.Equal(t, true, db.HasState(t.Context(), r))

	require.NoError(t, db.DeleteState(t.Context(), r))
	savedS, err := db.State(t.Context(), r)
	require.NoError(t, err)
	require.Equal(t, state.ReadOnlyBeaconState(nil), savedS, "Unsaved state should've been nil")
}

func TestDenebState_CanSaveRetrieve(t *testing.T) {
	db := setupDB(t)

	r := [32]byte{'A'}

	require.Equal(t, false, db.HasState(t.Context(), r))

	st, _ := util.DeterministicGenesisStateDeneb(t, 1)
	require.NoError(t, st.SetSlot(100))

	require.NoError(t, db.SaveState(t.Context(), st, r))
	require.Equal(t, true, db.HasState(t.Context(), r))

	savedS, err := db.State(t.Context(), r)
	require.NoError(t, err)

	require.DeepSSZEqual(t, st.ToProtoUnsafe(), savedS.ToProtoUnsafe())

	savedS, err = db.State(t.Context(), [32]byte{'B'})
	require.NoError(t, err)
	require.Equal(t, state.ReadOnlyBeaconState(nil), savedS, "Unsaved state should've been nil")
}

func TestDenebState_CanDelete(t *testing.T) {
	db := setupDB(t)

	r := [32]byte{'A'}

	require.Equal(t, false, db.HasState(t.Context(), r))

	st, _ := util.DeterministicGenesisStateDeneb(t, 1)
	require.NoError(t, st.SetSlot(100))

	require.NoError(t, db.SaveState(t.Context(), st, r))
	require.Equal(t, true, db.HasState(t.Context(), r))

	require.NoError(t, db.DeleteState(t.Context(), r))
	savedS, err := db.State(t.Context(), r)
	require.NoError(t, err)
	require.Equal(t, state.ReadOnlyBeaconState(nil), savedS, "Unsaved state should've been nil")
}

func TestStateDeneb_CanSaveRetrieveValidatorEntries(t *testing.T) {
	db := setupDB(t)

	// enable historical state representation flag to test this
	resetCfg := features.InitWithReset(&features.Flags{
		EnableHistoricalSpaceRepresentation: true,
	})
	defer resetCfg()

	r := [32]byte{'A'}

	require.Equal(t, false, db.HasState(t.Context(), r))

	stateValidators := validators(10)
	st, _ := util.DeterministicGenesisStateDeneb(t, 20)
	require.NoError(t, st.SetSlot(100))
	require.NoError(t, st.SetValidators(stateValidators))

	ctx := t.Context()
	require.NoError(t, db.SaveState(ctx, st, r))
	assert.Equal(t, true, db.HasState(t.Context(), r))

	savedS, err := db.State(t.Context(), r)
	require.NoError(t, err)

	require.DeepSSZEqual(t, st.Validators(), savedS.Validators(), "saved state with validators and retrieved state are not matching")

	// check if the index of the second state is still present.
	err = db.db.Update(func(tx *bolt.Tx) error {
		idxBkt := tx.Bucket(blockRootValidatorHashesBucket)
		data := idxBkt.Get(r[:])
		require.NotEqual(t, 0, len(data))
		return nil
	})
	require.NoError(t, err)

	// check if all the validator entries are still intact in the validator entry bucket.
	err = db.db.Update(func(tx *bolt.Tx) error {
		valBkt := tx.Bucket(stateValidatorsBucket)
		// if any of the original validator entry is not present, then fail the test.
		for _, val := range stateValidators {
			hash, hashErr := val.HashTreeRoot()
			assert.NoError(t, hashErr)
			data := valBkt.Get(hash[:])
			require.NotNil(t, data)
			require.NotEqual(t, 0, len(data))
		}
		return nil
	})
	require.NoError(t, err)
}

func TestElectraState_CanSaveRetrieve(t *testing.T) {
	db := setupDB(t)

	r := [32]byte{'A'}

	require.Equal(t, false, db.HasState(t.Context(), r))

	st, _ := util.DeterministicGenesisStateElectra(t, 1)
	require.NoError(t, st.SetSlot(100))

	require.NoError(t, db.SaveState(t.Context(), st, r))
	require.Equal(t, true, db.HasState(t.Context(), r))

	savedS, err := db.State(t.Context(), r)
	require.NoError(t, err)

	assert.DeepSSZEqual(t, st.ToProtoUnsafe(), savedS.ToProtoUnsafe())

	savedS, err = db.State(t.Context(), [32]byte{'B'})
	require.NoError(t, err)
	require.Equal(t, state.ReadOnlyBeaconState(nil), savedS, "Unsaved state should've been nil")
}

func TestElectraState_CanDelete(t *testing.T) {
	db := setupDB(t)

	r := [32]byte{'A'}

	require.Equal(t, false, db.HasState(t.Context(), r))

	st, _ := util.DeterministicGenesisStateElectra(t, 1)
	require.NoError(t, st.SetSlot(100))

	require.NoError(t, db.SaveState(t.Context(), st, r))
	require.Equal(t, true, db.HasState(t.Context(), r))

	require.NoError(t, db.DeleteState(t.Context(), r))
	savedS, err := db.State(t.Context(), r)
	require.NoError(t, err)
	require.Equal(t, state.ReadOnlyBeaconState(nil), savedS, "Unsaved state should've been nil")
}

func TestStateElectra_CanSaveRetrieveValidatorEntries(t *testing.T) {
	db := setupDB(t)

	// enable historical state representation flag to test this
	resetCfg := features.InitWithReset(&features.Flags{
		EnableHistoricalSpaceRepresentation: true,
	})
	defer resetCfg()

	r := [32]byte{'A'}

	require.Equal(t, false, db.HasState(t.Context(), r))

	stateValidators := validators(10)
	st, _ := util.DeterministicGenesisStateElectra(t, 20)
	require.NoError(t, st.SetSlot(100))
	require.NoError(t, st.SetValidators(stateValidators))

	ctx := t.Context()
	require.NoError(t, db.SaveState(ctx, st, r))
	assert.Equal(t, true, db.HasState(t.Context(), r))

	savedS, err := db.State(t.Context(), r)
	require.NoError(t, err)

	require.DeepSSZEqual(t, st.Validators(), savedS.Validators(), "saved state with validators and retrieved state are not matching")

	// check if the index of the second state is still present.
	err = db.db.Update(func(tx *bolt.Tx) error {
		idxBkt := tx.Bucket(blockRootValidatorHashesBucket)
		data := idxBkt.Get(r[:])
		require.NotEqual(t, 0, len(data))
		return nil
	})
	require.NoError(t, err)

	// check if all the validator entries are still intact in the validator entry bucket.
	err = db.db.Update(func(tx *bolt.Tx) error {
		valBkt := tx.Bucket(stateValidatorsBucket)
		// if any of the original validator entry is not present, then fail the test.
		for _, val := range stateValidators {
			hash, hashErr := val.HashTreeRoot()
			assert.NoError(t, hashErr)
			data := valBkt.Get(hash[:])
			require.NotNil(t, data)
			require.NotEqual(t, 0, len(data))
		}
		return nil
	})
	require.NoError(t, err)
}

func BenchmarkState_CheckStateSaveTime_1(b *testing.B)  { checkStateSaveTime(b, 1) }
func BenchmarkState_CheckStateSaveTime_10(b *testing.B) { checkStateSaveTime(b, 10) }

func BenchmarkState_CheckStateReadTime_1(b *testing.B)  { checkStateReadTime(b, 1) }
func BenchmarkState_CheckStateReadTime_10(b *testing.B) { checkStateReadTime(b, 10) }

func TestStore_CleanUpDirtyStates_NoOriginRoot(t *testing.T) {
	// This test verifies that CleanUpDirtyStates does not fail when the origin block root is not set,
	// which can happen when starting from genesis or in certain fork scenarios like Fulu.
	db := setupDB(t)
	genesisState, err := util.NewBeaconState()
	require.NoError(t, err)
	genesisRoot := [fieldparams.RootLength]byte{'a'}
	require.NoError(t, db.SaveGenesisBlockRoot(t.Context(), genesisRoot))
	require.NoError(t, db.SaveState(t.Context(), genesisState, genesisRoot))
	// Note: We intentionally do NOT call SaveOriginCheckpointBlockRoot here
	// to simulate the scenario where origin block root is not set
	slotsPerArchivedPoint := primitives.Slot(128)
	bRoots := make([][fieldparams.RootLength]byte, 0)
	prevRoot := genesisRoot
	for i := primitives.Slot(1); i <= slotsPerArchivedPoint; i++ { // skip slot 0
		b := util.NewBeaconBlock()
		b.Block.Slot = i
		b.Block.ParentRoot = prevRoot[:]
		r, err := b.Block.HashTreeRoot()
		require.NoError(t, err)
		wsb, err := blocks.NewSignedBeaconBlock(b)
		require.NoError(t, err)
		require.NoError(t, db.SaveBlock(t.Context(), wsb))
		bRoots = append(bRoots, r)
		prevRoot = r
		st, err := util.NewBeaconState()
		require.NoError(t, err)
		require.NoError(t, st.SetSlot(i))
		require.NoError(t, db.SaveState(t.Context(), st, r))
	}
	require.NoError(t, db.SaveFinalizedCheckpoint(t.Context(), &ethpb.Checkpoint{
		Root:  bRoots[len(bRoots)-1][:],
		Epoch: primitives.Epoch(slotsPerArchivedPoint / params.BeaconConfig().SlotsPerEpoch),
	}))
	// This should not fail even though origin block root is not set
	err = db.CleanUpDirtyStates(t.Context(), slotsPerArchivedPoint)
	require.NoError(t, err)
	// Verify that cleanup still works correctly
	for i, root := range bRoots {
		if primitives.Slot(i) >= slotsPerArchivedPoint.SubSlot(slotsPerArchivedPoint.Div(3)) {
			require.Equal(t, true, db.HasState(t.Context(), root))
		} else {
			require.Equal(t, false, db.HasState(t.Context(), root))
		}
	}
}

func TestStore_CanSaveRetrieveStateUsingStateDiff(t *testing.T) {
	t.Run("No state summary or block", func(t *testing.T) {
		db := setupDB(t)
		featCfg := &features.Flags{}
		featCfg.EnableStateDiff = true
		reset := features.InitWithReset(featCfg)
		defer reset()
		setDefaultStateDiffExponents()

		err := setOffsetInDB(db, 0)
		require.NoError(t, err)

		readSt, err := db.State(context.Background(), [32]byte{'A'})
		require.IsNil(t, readSt)
		require.ErrorContains(t, "neither state summary nor block found", err)
	})

	t.Run("Slot not in tree", func(t *testing.T) {
		db := setupDB(t)
		featCfg := &features.Flags{}
		featCfg.EnableStateDiff = true
		reset := features.InitWithReset(featCfg)
		defer reset()
		setDefaultStateDiffExponents()

		err := setOffsetInDB(db, 0)
		require.NoError(t, err)

		r := bytesutil.ToBytes32([]byte{'A'})
		ss := &ethpb.StateSummary{Slot: 1, Root: r[:]} // slot 1 not in tree
		err = db.SaveStateSummary(context.Background(), ss)
		require.NoError(t, err)

		readSt, err := db.State(context.Background(), r)
		require.ErrorContains(t, "slot not in tree", err)
		require.IsNil(t, readSt)

	})

	t.Run("State not found", func(t *testing.T) {
		db := setupDB(t)
		featCfg := &features.Flags{}
		featCfg.EnableStateDiff = true
		reset := features.InitWithReset(featCfg)
		defer reset()
		setDefaultStateDiffExponents()

		err := setOffsetInDB(db, 0)
		require.NoError(t, err)

		r := bytesutil.ToBytes32([]byte{'A'})
		ss := &ethpb.StateSummary{Slot: 32, Root: r[:]} // slot 32 is in tree
		err = db.SaveStateSummary(context.Background(), ss)
		require.NoError(t, err)

		readSt, err := db.State(context.Background(), r)
		require.ErrorContains(t, "state not found", err)
		require.IsNil(t, readSt)
	})

	t.Run("Full state snapshot", func(t *testing.T) {
		t.Run("using state summary", func(t *testing.T) {
			for v := range version.All() {
				t.Run(version.String(v), func(t *testing.T) {
					db := setupDB(t)
					featCfg := &features.Flags{}
					featCfg.EnableStateDiff = true
					reset := features.InitWithReset(featCfg)
					defer reset()
					setDefaultStateDiffExponents()

					err := setOffsetInDB(db, 0)
					require.NoError(t, err)

					st, _ := createState(t, 0, v)

					err = db.saveStateByDiff(context.Background(), st)
					require.NoError(t, err)

					r := bytesutil.ToBytes32([]byte{'A'})
					ss := &ethpb.StateSummary{Slot: 0, Root: r[:]}
					err = db.SaveStateSummary(context.Background(), ss)
					require.NoError(t, err)

					readSt, err := db.State(context.Background(), r)
					require.NoError(t, err)
					require.NotNil(t, readSt)

					stSSZ, err := st.MarshalSSZ()
					require.NoError(t, err)
					readStSSZ, err := readSt.MarshalSSZ()
					require.NoError(t, err)
					require.DeepSSZEqual(t, stSSZ, readStSSZ)
				})
			}
		})

		t.Run("using block", func(t *testing.T) {
			for v := range version.All() {
				t.Run(version.String(v), func(t *testing.T) {
					db := setupDB(t)
					featCfg := &features.Flags{}
					featCfg.EnableStateDiff = true
					reset := features.InitWithReset(featCfg)
					defer reset()
					setDefaultStateDiffExponents()

					err := setOffsetInDB(db, 0)
					require.NoError(t, err)

					st, _ := createState(t, 0, v)

					err = db.saveStateByDiff(context.Background(), st)
					require.NoError(t, err)

					blk := util.NewBeaconBlock()
					blk.Block.Slot = 0
					signedBlk, err := blocks.NewSignedBeaconBlock(blk)
					require.NoError(t, err)
					err = db.SaveBlock(context.Background(), signedBlk)
					require.NoError(t, err)
					r, err := signedBlk.Block().HashTreeRoot()
					require.NoError(t, err)

					readSt, err := db.State(context.Background(), r)
					require.NoError(t, err)
					require.NotNil(t, readSt)

					stSSZ, err := st.MarshalSSZ()
					require.NoError(t, err)
					readStSSZ, err := readSt.MarshalSSZ()
					require.NoError(t, err)
					require.DeepSSZEqual(t, stSSZ, readStSSZ)
				})
			}
		})
	})

	t.Run("Diffed state", func(t *testing.T) {
		t.Run("using state summary", func(t *testing.T) {
			for v := range version.All() {
				t.Run(version.String(v), func(t *testing.T) {
					db := setupDB(t)
					featCfg := &features.Flags{}
					featCfg.EnableStateDiff = true
					reset := features.InitWithReset(featCfg)
					defer reset()
					setDefaultStateDiffExponents()

					exponents := flags.Get().StateDiffExponents

					err := setOffsetInDB(db, 0)
					require.NoError(t, err)

					st, _ := createState(t, 0, v)
					err = db.saveStateByDiff(context.Background(), st)
					require.NoError(t, err)

					slot := primitives.Slot(math.PowerOf2(uint64(exponents[len(exponents)-2])))
					st, _ = createState(t, slot, v)
					err = db.saveStateByDiff(context.Background(), st)
					require.NoError(t, err)

					slot = primitives.Slot(math.PowerOf2(uint64(exponents[len(exponents)-2])) + math.PowerOf2(uint64(exponents[len(exponents)-1])))
					st, _ = createState(t, slot, v)
					err = db.saveStateByDiff(context.Background(), st)
					require.NoError(t, err)

					r := bytesutil.ToBytes32([]byte{'A'})
					ss := &ethpb.StateSummary{Slot: slot, Root: r[:]}
					err = db.SaveStateSummary(context.Background(), ss)
					require.NoError(t, err)

					readSt, err := db.State(context.Background(), r)
					require.NoError(t, err)
					require.NotNil(t, readSt)

					stSSZ, err := st.MarshalSSZ()
					require.NoError(t, err)
					readStSSZ, err := readSt.MarshalSSZ()
					require.NoError(t, err)
					require.DeepSSZEqual(t, stSSZ, readStSSZ)
				})
			}
		})

		t.Run("using block", func(t *testing.T) {
			for v := range version.All() {
				t.Run(version.String(v), func(t *testing.T) {
					db := setupDB(t)
					featCfg := &features.Flags{}
					featCfg.EnableStateDiff = true
					reset := features.InitWithReset(featCfg)
					defer reset()
					setDefaultStateDiffExponents()

					exponents := flags.Get().StateDiffExponents

					err := setOffsetInDB(db, 0)
					require.NoError(t, err)

					st, _ := createState(t, 0, v)

					err = db.saveStateByDiff(context.Background(), st)
					require.NoError(t, err)

					slot := primitives.Slot(math.PowerOf2(uint64(exponents[len(exponents)-2])))
					st, _ = createState(t, slot, v)
					err = db.saveStateByDiff(context.Background(), st)
					require.NoError(t, err)

					slot = primitives.Slot(math.PowerOf2(uint64(exponents[len(exponents)-2])) + math.PowerOf2(uint64(exponents[len(exponents)-1])))
					st, _ = createState(t, slot, v)
					err = db.saveStateByDiff(context.Background(), st)
					require.NoError(t, err)

					blk := util.NewBeaconBlock()
					blk.Block.Slot = slot
					signedBlk, err := blocks.NewSignedBeaconBlock(blk)
					require.NoError(t, err)
					err = db.SaveBlock(context.Background(), signedBlk)
					require.NoError(t, err)
					r, err := signedBlk.Block().HashTreeRoot()
					require.NoError(t, err)

					readSt, err := db.State(context.Background(), r)
					require.NoError(t, err)
					require.NotNil(t, readSt)

					stSSZ, err := st.MarshalSSZ()
					require.NoError(t, err)
					readStSSZ, err := readSt.MarshalSSZ()
					require.NoError(t, err)
					require.DeepSSZEqual(t, stSSZ, readStSSZ)
				})
			}
		})
	})
}

func TestStore_HasStateUsingStateDiff(t *testing.T) {
	t.Run("No state summary or block", func(t *testing.T) {
		hook := logTest.NewGlobal()
		db := setupDB(t)
		featCfg := &features.Flags{}
		featCfg.EnableStateDiff = true
		reset := features.InitWithReset(featCfg)
		defer reset()
		setDefaultStateDiffExponents()

		err := setOffsetInDB(db, 0)
		require.NoError(t, err)

		hasSt := db.HasState(t.Context(), [32]byte{'A'})
		require.Equal(t, false, hasSt)
		require.LogsContain(t, hook, "neither state summary nor block found")
	})

	t.Run("slot in tree or not", func(t *testing.T) {
		db := setupDB(t)
		featCfg := &features.Flags{}
		featCfg.EnableStateDiff = true
		reset := features.InitWithReset(featCfg)
		defer reset()
		setDefaultStateDiffExponents()

		err := setOffsetInDB(db, 0)
		require.NoError(t, err)

		testCases := []struct {
			slot     primitives.Slot
			expected bool
		}{
			{slot: 1, expected: false},                                      // slot 1 not in tree
			{slot: 32, expected: true},                                      // slot 32 in tree
			{slot: 0, expected: true},                                       // slot 0 in tree
			{slot: primitives.Slot(math.PowerOf2(21)), expected: true},      // slot in tree
			{slot: primitives.Slot(math.PowerOf2(21) - 1), expected: false}, // slot not in tree
			{slot: primitives.Slot(math.PowerOf2(22)), expected: true},      // slot in tree
		}

		for _, tc := range testCases {
			r := bytesutil.ToBytes32([]byte{'A'})
			ss := &ethpb.StateSummary{Slot: tc.slot, Root: r[:]}
			err = db.SaveStateSummary(t.Context(), ss)
			require.NoError(t, err)

			hasSt := db.HasState(t.Context(), r)
			require.Equal(t, tc.expected, hasSt)
		}

	})
}
